{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "from fastai.basics import *\n",
    "from fastai.vision.core import *\n",
    "from fastai.vision.data import *\n",
    "from fastai.vision.augment import *\n",
    "from fastai.vision import models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp vision.learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learner for the vision applications\n",
    "\n",
    "> All the functions necessary to build `Learner` suitable for transfer learning in computer vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most important functions of this module are `cnn_learner` and `unet_learner`. They will help you define a `Learner` using a pretrained model. See the [vision tutorial](http://docs.fast.ai/tutorial.vision) for examples of use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cut a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _is_pool_type(l): return re.search(r'Pool[123]d$', l.__class__.__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
    "test_eq([bool(_is_pool_type(m_)) for m_ in m.children()], [True,False,False,True])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the fastai library cuts a pretrained model at the pooling layer. This function helps detecting it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def has_pool_type(m):\n",
    "    \"Return `True` if `m` is a pooling layer or has one in its children\"\n",
    "    if _is_pool_type(m): return True\n",
    "    for l in m.children():\n",
    "        if has_pool_type(l): return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
    "assert has_pool_type(m)\n",
    "test_eq([has_pool_type(m_) for m_ in m.children()], [True,False,False,True])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def _get_first_layer(m):\n",
    "    \"Access first layer of a model\"\n",
    "    c,p,n = m,None,None  # child, parent, name\n",
    "    for n in next(m.named_parameters())[0].split('.')[:-1]:\n",
    "        p,c=c,getattr(c,n)\n",
    "    return c,p,n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _load_pretrained_weights(new_layer, previous_layer):\n",
    "    \"Load pretrained weights based on number of input channels\"\n",
    "    n_in = getattr(new_layer, 'in_channels')\n",
    "    if n_in==1:\n",
    "        # we take the sum\n",
    "        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)\n",
    "    elif n_in==2:\n",
    "        # we take first 2 channels + 50%\n",
    "        new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5\n",
    "    else:\n",
    "        # keep 3 channels weights and set others to null\n",
    "        new_layer.weight.data[:,:3] = previous_layer.weight.data\n",
    "        new_layer.weight.data[:,3:].zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _update_first_layer(model, n_in, pretrained):\n",
    "    \"Change first layer based on number of input channels\"\n",
    "    if n_in == 3: return\n",
    "    first_layer, parent, name = _get_first_layer(model)\n",
    "    assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'\n",
    "    assert getattr(first_layer, 'in_channels') == 3, f'Unexpected number of input channels, found {getattr(first_layer, \"in_channels\")} while expecting 3'\n",
    "    params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}\n",
    "    params['bias'] = getattr(first_layer, 'bias') is not None\n",
    "    params['in_channels'] = n_in\n",
    "    new_layer = nn.Conv2d(**params)\n",
    "    if pretrained:\n",
    "        _load_pretrained_weights(new_layer, first_layer)\n",
    "    setattr(parent, name, new_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_body(arch, n_in=3, pretrained=True, cut=None):\n",
    "    \"Cut off the body of a typically pretrained `arch` as determined by `cut`\"\n",
    "    model = arch(pretrained=pretrained)\n",
    "    _update_first_layer(model, n_in, pretrained)\n",
    "    #cut = ifnone(cut, cnn_config(arch)['cut'])\n",
    "    if cut is None:\n",
    "        ll = list(enumerate(model.children()))\n",
    "        cut = next(i for i,o in reversed(ll) if has_pool_type(o))\n",
    "    if   isinstance(cut, int):      return nn.Sequential(*list(model.children())[:cut])\n",
    "    elif callable(cut): return cut(model)\n",
    "    else:                           raise NamedError(\"cut must be either integer or a function\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cut` can either be an integer, in which case we cut the model at the corresponding layer, or a function, in which case, this function returns `cut(model)`. It defaults to the first layer that contains some pooling otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = lambda pretrained : nn.Sequential(nn.Conv2d(3,5,3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3,4))\n",
    "m = create_body(tst)\n",
    "test_eq(len(m), 2)\n",
    "\n",
    "m = create_body(tst, cut=3)\n",
    "test_eq(len(m), 3)\n",
    "\n",
    "m = create_body(tst, cut=noop)\n",
    "test_eq(len(m), 4)\n",
    "\n",
    "for n in range(1,5):    \n",
    "    m = create_body(tst, n_in=n)\n",
    "    test_eq(_get_first_layer(m)[0].in_channels, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Head and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def create_head(nf, n_out, lin_ftrs=None, ps=0.5, concat_pool=True, bn_final=False, lin_first=False, y_range=None):\n",
    "    \"Model head that takes `nf` features, runs through `lin_ftrs`, and out `n_out` classes.\"\n",
    "    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]\n",
    "    ps = L(ps)\n",
    "    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n",
    "    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n",
    "    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)\n",
    "    layers = [pool, Flatten()]\n",
    "    if lin_first: layers.append(nn.Dropout(ps.pop(0)))\n",
    "    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):\n",
    "        layers += LinBnDrop(ni, no, bn=True, p=p, act=actn, lin_first=lin_first)\n",
    "    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))\n",
    "    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The head begins with fastai's `AdaptiveConcatPool2d` if `concat_pool=True` otherwise, it uses traditional average pooling. Then it uses a `Flatten` layer before going on blocks of `BatchNorm`, `Dropout` and `Linear` layers (if `lin_first=True`, those are `Linear`, `BatchNorm`, `Dropout`).\n",
    "\n",
    "Those blocks start at `nf`, then every element of `lin_ftrs` (defaults to `[512]`) and end at `n_out`. `ps` is a list of probabilities used for the dropouts (if you only pass 1, it will use half the value then that value as many times as necessary).\n",
    "\n",
    "If `bn_final=True`, a final `BatchNorm` layer is added. If `y_range` is passed, the function adds a `SigmoidRange` to that range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): Flatten(full=False)\n",
       "  (2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=5, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=10, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tst = create_head(5, 10)\n",
    "tst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 9)\n",
    "assert isinstance(mods[2], nn.BatchNorm1d)\n",
    "assert isinstance(mods[-1], nn.Linear)\n",
    "\n",
    "tst = create_head(5, 10, lin_first=True)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 8)\n",
    "assert isinstance(mods[2], nn.Dropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.callback.hook import num_features_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#TODO: refactor, i.e. something like this?\n",
    "# class ModelSplitter():\n",
    "#     def __init__(self, idx): self.idx = idx\n",
    "#     def split(self, m): return L(m[:self.idx], m[self.idx:]).map(params)\n",
    "#     def __call__(self,): return {'cut':self.idx, 'split':self.split}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def default_split(m):\n",
    "    \"Default split of a model between body and head\"\n",
    "    return L(m[0], m[1:]).map(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To do transfer learning, you need to pass a `splitter` to `Learner`. This should be a function taking the model and returning a collection of parameter groups, e.g. a list of list of parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _xresnet_split(m): return L(m[0][:3], m[0][3:], m[1:]).map(params)\n",
    "def  _resnet_split(m): return L(m[0][:6], m[0][6:], m[1:]).map(params)\n",
    "def _squeezenet_split(m:nn.Module): return L(m[0][0][:5], m[0][0][5:], m[1:]).map(params)\n",
    "def _densenet_split(m:nn.Module): return L(m[0][0][:7],m[0][0][7:], m[1:]).map(params)\n",
    "def _vgg_split(m:nn.Module): return L(m[0][0][:22], m[0][0][22:], m[1:]).map(params)\n",
    "def _alexnet_split(m:nn.Module): return L(m[0][0][:6], m[0][0][6:], m[1:]).map(params)\n",
    "\n",
    "_default_meta    = {'cut':None, 'split':default_split}\n",
    "_xresnet_meta    = {'cut':-4, 'split':_xresnet_split, 'stats':imagenet_stats}\n",
    "_resnet_meta     = {'cut':-2, 'split':_resnet_split, 'stats':imagenet_stats}\n",
    "_squeezenet_meta = {'cut':-1, 'split': _squeezenet_split, 'stats':imagenet_stats}\n",
    "_densenet_meta   = {'cut':-1, 'split':_densenet_split, 'stats':imagenet_stats}\n",
    "_vgg_meta        = {'cut':-2, 'split':_vgg_split, 'stats':imagenet_stats}\n",
    "_alexnet_meta    = {'cut':-2, 'split':_alexnet_split, 'stats':imagenet_stats}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "model_meta = {\n",
    "    models.xresnet.xresnet18 :{**_xresnet_meta}, models.xresnet.xresnet34: {**_xresnet_meta},\n",
    "    models.xresnet.xresnet50 :{**_xresnet_meta}, models.xresnet.xresnet101:{**_xresnet_meta},\n",
    "    models.xresnet.xresnet152:{**_xresnet_meta},\n",
    "\n",
    "    models.resnet18 :{**_resnet_meta}, models.resnet34: {**_resnet_meta},\n",
    "    models.resnet50 :{**_resnet_meta}, models.resnet101:{**_resnet_meta},\n",
    "    models.resnet152:{**_resnet_meta},\n",
    "\n",
    "    models.squeezenet1_0:{**_squeezenet_meta},\n",
    "    models.squeezenet1_1:{**_squeezenet_meta},\n",
    "\n",
    "    models.densenet121:{**_densenet_meta}, models.densenet169:{**_densenet_meta},\n",
    "    models.densenet201:{**_densenet_meta}, models.densenet161:{**_densenet_meta},\n",
    "    models.vgg11_bn:{**_vgg_meta}, models.vgg13_bn:{**_vgg_meta}, models.vgg16_bn:{**_vgg_meta}, models.vgg19_bn:{**_vgg_meta},\n",
    "    models.alexnet:{**_alexnet_meta}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(create_head)\n",
    "def create_cnn_model(arch, n_out, pretrained=True, cut=None, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,\n",
    "                     concat_pool=True, **kwargs):\n",
    "    \"Create custom convnet architecture\"\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))\n",
    "    if custom_head is None:\n",
    "        nf = num_features_model(nn.Sequential(*body.children())) * (2 if concat_pool else 1)\n",
    "        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)\n",
    "    else: head = custom_head\n",
    "    model = nn.Sequential(body, head)\n",
    "    if init is not None: apply_init(model[1], init)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"create_cnn_model\" class=\"doc_header\"><code>create_cnn_model</code><a href=\"__main__.py#L2\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>create_cnn_model</code>(**`arch`**, **`n_out`**, **`pretrained`**=*`True`*, **`cut`**=*`None`*, **`n_in`**=*`3`*, **`init`**=*`kaiming_normal_`*, **`custom_head`**=*`None`*, **`concat_pool`**=*`True`*, **`lin_ftrs`**=*`None`*, **`ps`**=*`0.5`*, **`bn_final`**=*`False`*, **`lin_first`**=*`False`*, **`y_range`**=*`None`*)\n",
       "\n",
       "Create custom convnet architecture"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(create_cnn_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is cut according to `cut` and it may be `pretrained`, in which case, the proper set of weights is downloaded then loaded. `init` is applied to the head of the model, which is either created by `create_head` (with `lin_ftrs`, `ps`, `concat_pool`, `bn_final`, `lin_first` and `y_range`) or is `custom_head`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = create_cnn_model(models.resnet18, 10, True)\n",
    "tst = create_cnn_model(models.resnet18, 10, True, n_in=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pets = DataBlock(blocks=(ImageBlock, CategoryBlock), \n",
    "                 get_items=get_image_files, \n",
    "                 splitter=RandomSplitter(),\n",
    "                 get_y=RegexLabeller(pat = r'/([^/]+)_\\d+.jpg$'))\n",
    "\n",
    "dls = pets.dataloaders(untar_data(URLs.PETS)/\"images\", item_tfms=RandomResizedCrop(300, min_scale=0.5), bs=64,\n",
    "                        batch_tfms=[*aug_transforms(size=224)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Learner` convenience functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _add_norm(dls, meta, pretrained):\n",
    "    if not pretrained: return\n",
    "    after_batch = dls.after_batch\n",
    "    if first(o for o in after_batch.fs if isinstance(o,Normalize)): return\n",
    "    stats = meta.get('stats')\n",
    "    if stats is None: return\n",
    "    after_batch.add(Normalize.from_stats(*stats))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(create_cnn_model)\n",
    "def cnn_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,\n",
    "                # learner args\n",
    "                loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
    "                model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),\n",
    "                # other model args\n",
    "                **kwargs):\n",
    "    \"Build a convnet style learner from `dls` and `arch`\"\n",
    "    \n",
    "    if config:\n",
    "        warnings.warn('config param is deprecated. Pass your args directly to cnn_learner.')\n",
    "        kwargs = {**config, **kwargs}\n",
    "    \n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    if normalize: _add_norm(dls, meta, pretrained)\n",
    "    \n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    model = create_cnn_model(arch, n_out, pretrained=pretrained, **kwargs)\n",
    "    \n",
    "    splitter=ifnone(splitter, meta['split'])\n",
    "    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,\n",
    "                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,\n",
    "                   moms=moms)\n",
    "    if pretrained: learn.freeze()\n",
    "    # keep track of args for loggers\n",
    "    store_attr('arch,normalize,n_out,pretrained', self=learn)\n",
    "    if kwargs: store_attr(self=learn, **kwargs)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is built from `arch` using the number of final activations inferred from `dls` if possible (otherwise pass a value to `n_out`). It might be `pretrained` and the architecture is cut and split using the default metadata of the model architecture (this can be customized by passing a `cut` or a `splitter`). \n",
    "\n",
    "If `normalize` and `pretrained` are `True`, this function adds a `Normalization` transform to the `dls` (if there is not already one) using the statistics of the pretrained model. That way, you won't ever forget to normalize your data in transfer learning.\n",
    "\n",
    "All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)\n",
    "fnames = get_image_files(path/\"images\")\n",
    "pat = r'^(.*)_\\d+.jpg$'\n",
    "dls = ImageDataLoaders.from_name_re(path, fnames, pat, item_tfms=Resize(224))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = cnn_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(), ps=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "test_eq(to_cpu(dls.after_batch[1].mean[0].squeeze()), tensor(imagenet_stats[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(models.unet.DynamicUnet.__init__)\n",
    "def create_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):\n",
    "    \"Create custom unet architecture\"\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))    \n",
    "    model = models.unet.DynamicUnet(body, n_out, img_size, **kwargs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h4 id=\"create_unet_model\" class=\"doc_header\"><code>create_unet_model</code><a href=\"__main__.py#L2\" class=\"source_link\" style=\"float:right\">[source]</a></h4>\n",
       "\n",
       "> <code>create_unet_model</code>(**`arch`**, **`n_out`**, **`img_size`**, **`pretrained`**=*`True`*, **`cut`**=*`None`*, **`n_in`**=*`3`*, **`blur`**=*`False`*, **`blur_final`**=*`True`*, **`self_attention`**=*`False`*, **`y_range`**=*`None`*, **`last_cross`**=*`True`*, **`bottle`**=*`False`*, **`act_cls`**=*`ReLU`*, **`init`**=*`kaiming_normal_`*, **`norm_type`**=*`None`*)\n",
       "\n",
       "Create custom unet architecture"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(create_unet_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = create_unet_model(models.resnet18, 10, (24,24), True, n_in=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(create_unet_model)\n",
    "def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,\n",
    "                 # learner args\n",
    "                 loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
    "                 model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),\n",
    "                 # other model args\n",
    "                 **kwargs):    \n",
    "    \"Build a unet learner from `dls` and `arch`\"\n",
    "    \n",
    "    if config:\n",
    "        warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')\n",
    "        kwargs = {**config, **kwargs}\n",
    "    \n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    if normalize: _add_norm(dls, meta, pretrained)\n",
    "    \n",
    "    n_out = ifnone(n_out, get_c(dls))\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    img_size = dls.one_batch()[0].shape[-2:]\n",
    "    assert img_size, \"image size could not be inferred from data\"\n",
    "    model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)\n",
    "\n",
    "    splitter=ifnone(splitter, meta['split'])\n",
    "    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,\n",
    "                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,\n",
    "                   moms=moms)\n",
    "    if pretrained: learn.freeze()\n",
    "    # keep track of args for loggers\n",
    "    store_attr('arch,normalize,n_out,pretrained', self=learn)\n",
    "    if kwargs: store_attr(self=learn, **kwargs)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is built from `arch` using the number of final filters inferred from `dls` if possible (otherwise pass a value to `n_out`). It might be `pretrained` and the architecture is cut and split using the default metadata of the model architecture (this can be customized by passing a `cut` or a `splitter`). \n",
    "\n",
    "If `normalize` and `pretrained` are `True`, this function adds a `Normalization` transform to the `dls` (if there is not already one) using the statistics of the pretrained model. That way, you won't ever forget to normalize your data in transfer learning.\n",
    "\n",
    "All other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.CAMVID_TINY)\n",
    "fnames = get_image_files(path/'images')\n",
    "def label_func(x): return path/'labels'/f'{x.stem}_P{x.suffix}'\n",
    "codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
    "    \n",
    "dls = SegmentationDataLoaders.from_label_func(path, fnames, label_func, codes=codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = unet_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "learn = unet_learner(dls, models.resnet34, pretrained=True, n_in=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-34-1a35cbb396b7>:12: UserWarning: config param is deprecated. Pass your args directly to unet_learner.\n",
      "  warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "learn = unet_learner(dls, models.resnet34, pretrained=True, config={'n_in':4})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show functions -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:TensorImage, y, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize)\n",
    "    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize)\n",
    "    for i in range(2):\n",
    "        ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]\n",
    "    ctxs = [r.show(ctx=c, color='green' if b==r else 'red', **kwargs)\n",
    "            for b,r,c,_ in zip(samples.itemgot(1),outs.itemgot(0),ctxs,range(max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:TensorImage, y:(TensorMask, TensorPoint, TensorBBox), samples, outs, ctxs=None, max_n=6,\n",
    "                 nrows=None, ncols=1, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize, double=True,\n",
    "                                     title='Target/Prediction')\n",
    "    for i in range(2):\n",
    "        ctxs[::2] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs[::2],range(2*max_n))]\n",
    "    for o in [samples,outs]:\n",
    "        ctxs[1::2] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(o.itemgot(0),ctxs[1::2],range(2*max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=10, figsize=None, **kwargs):\n",
    "    if ctxs is None: ctxs = get_grid(3*min(len(samples), max_n), ncols=3, figsize=figsize, title='Input/Target/Prediction')\n",
    "    for i in range(2):\n",
    "        ctxs[i::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs[i::3],range(max_n))]\n",
    "    ctxs[2::3] = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs[2::3],range(max_n))]\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def plot_top_losses(x: TensorImage, y:TensorCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    axs = get_grid(len(samples), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize, title='Prediction/Actual/Loss/Probability')\n",
    "    for ax,s,o,r,l in zip(axs, samples, outs, raws, losses):\n",
    "        s[0].show(ctx=ax, **kwargs)\n",
    "        ax.set_title(f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@typedispatch\n",
    "def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    axs = get_grid(len(samples), nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize)\n",
    "    for i,(ax,s) in enumerate(zip(axs, samples)): s[0].show(ctx=ax, title=f'Image {i}', **kwargs)\n",
    "    rows = get_empty_df(len(samples))\n",
    "    outs = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item())) for s,o,r,l in zip(samples, outs, raws, losses))\n",
    "    for i,l in enumerate([\"target\", \"predicted\", \"probabilities\", \"loss\"]):\n",
    "        rows = [b.show(ctx=r, label=l, **kwargs) for b,r in zip(outs.itemgot(i),rows)]\n",
    "    display_df(pd.DataFrame(rows))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.cutmix.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted Untitled.ipynb.\n",
      "Converted decorate.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
